In [1]:
# Install a pip package in the current Jupyter kernel
import sys
!{sys.executable} -m pip install numpy==1.16.4
Collecting numpy==1.16.4
  Using cached https://files.pythonhosted.org/packages/0f/c9/3526a357b6c35e5529158fbcfac1bb3adc8827e8809a6d254019d326d1cc/numpy-1.16.4-cp36-cp36m-macosx_10_6_intel.macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl
ERROR: tensorflow 1.14.0 requires google-pasta>=0.1.6, which is not installed.
Installing collected packages: numpy
  Found existing installation: numpy 1.17.4
    Uninstalling numpy-1.17.4:
      Successfully uninstalled numpy-1.17.4
Successfully installed numpy-1.16.4
In [2]:
!conda list
# packages in environment at /anaconda3/envs/unet:
#
# Name                    Version                   Build  Channel
_tflow_select             2.3.0                       mkl  
absl-py                   0.9.0                    py36_0    conda-forge
appnope                   0.1.0                 py36_1000    conda-forge
astor                     0.7.1                      py_0    conda-forge
attrs                     19.3.0                     py_0    conda-forge
backcall                  0.1.0                      py_0    conda-forge
bleach                    3.1.0                      py_0    conda-forge
boto3                     1.11.9                   pypi_0    pypi
botocore                  1.14.9                   pypi_0    pypi
bzip2                     1.0.8                h0b31af3_2    conda-forge
c-ares                    1.15.0            h01d97ff_1001    conda-forge
ca-certificates           2020.1.1                      0    anaconda
cairo                     1.16.0            he1c11cd_1002    conda-forge
certifi                   2019.11.28               py36_0    anaconda
cycler                    0.10.0                     py_2    conda-forge
decorator                 4.4.1                      py_0    conda-forge
defusedxml                0.6.0                      py_0    conda-forge
docutils                  0.15.2                   pypi_0    pypi
entrypoints               0.3                   py36_1000    conda-forge
ffmpeg                    4.1.3                h5c2b479_0    conda-forge
fontconfig                2.13.1            h6b1039f_1001    conda-forge
freetype                  2.10.0               h24853df_1    conda-forge
fsspec                    0.6.2                    pypi_0    pypi
gast                      0.3.3                      py_0    conda-forge
gettext                   0.19.8.1          h46ab8bc_1002    conda-forge
giflib                    5.1.7                h01d97ff_1    conda-forge
glib                      2.58.3            h9d45998_1002    conda-forge
gmp                       6.1.2             h0a44026_1000    conda-forge
gnutls                    3.6.5             h53004b3_1002    conda-forge
graphite2                 1.3.13            h2098e52_1000    conda-forge
grpcio                    1.23.0           py36h6ef0057_0    conda-forge
h5py                      2.10.0          nompi_py36h106b333_102    conda-forge
harfbuzz                  2.4.0                hd8d2a14_3    conda-forge
hdf5                      1.10.5          nompi_h15a436c_1103    conda-forge
icu                       64.2                 h6de7cb9_1    conda-forge
importlib_metadata        1.4.0                    py36_0    conda-forge
ipykernel                 5.1.3            py36h5ca1d4c_0    conda-forge
ipython                   7.11.1           py36h5ca1d4c_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
jasper                    1.900.1           h636a363_1006    conda-forge
jedi                      0.15.2                   py36_0    conda-forge
jinja2                    2.10.3                     py_0    conda-forge
jmespath                  0.9.4                    pypi_0    pypi
joblib                    0.14.1                   pypi_0    pypi
jpeg                      9c                h1de35cc_1001    conda-forge
jsonschema                3.2.0                    py36_0    conda-forge
jupyter_client            5.3.4                    py36_1    conda-forge
jupyter_core              4.6.1                    py36_0    conda-forge
keras                     2.2.4                    py36_1    conda-forge
keras-applications        1.0.8                      py_1    conda-forge
keras-preprocessing       1.1.0                      py_0    conda-forge
keras-unet                0.0.7                    pypi_0    pypi
kiwisolver                1.1.0            py36h770b8ee_0    conda-forge
lame                      3.100             h1de35cc_1001    conda-forge
libblas                   3.8.0               15_openblas    conda-forge
libcblas                  3.8.0               15_openblas    conda-forge
libcxx                    4.0.1                hcfea43d_1  
libcxxabi                 4.0.1                hcfea43d_1  
libedit                   3.1.20181209         hb402a30_0  
libffi                    3.2.1                h475c297_4  
libgfortran               4.0.0                         2    conda-forge
libgpuarray               0.7.6             h1de35cc_1003    conda-forge
libiconv                  1.15              h01d97ff_1005    conda-forge
liblapack                 3.8.0               15_openblas    conda-forge
liblapacke                3.8.0               15_openblas    conda-forge
libopenblas               0.3.8                h3d69b6c_0    conda-forge
libopencv                 4.1.1                hb60cc42_3    conda-forge
libpng                    1.6.37               h2573ce8_0    conda-forge
libprotobuf               3.9.2                hfbae3c0_0    conda-forge
libsodium                 1.0.17               h01d97ff_0    conda-forge
libtiff                   4.0.10            hd08fb8f_1003    conda-forge
libwebp                   1.0.2                h20df551_3    conda-forge
libxml2                   2.9.10               h53d96d6_0    conda-forge
llvm-openmp               9.0.1                h28b9765_2    conda-forge
lz4-c                     1.8.3             h6de7cb9_1001    conda-forge
mako                      1.1.0                      py_0    conda-forge
markdown                  3.2.1                      py_0    conda-forge
markupsafe                1.1.1            py36h0b31af3_0    conda-forge
matplotlib                3.1.2                    pypi_0    pypi
matplotlib-base           3.1.1            py36h3a684a6_1    conda-forge
mistune                   0.8.4           py36h0b31af3_1000    conda-forge
more-itertools            8.1.0                      py_0    conda-forge
nbconvert                 5.6.1                    py36_0    conda-forge
nbformat                  5.0.3                      py_0    conda-forge
ncurses                   6.1                  h0a44026_1  
nettle                    3.4.1             h3efe00b_1002    conda-forge
notebook                  6.0.1                    py36_0    conda-forge
numpy                     1.16.4                   pypi_0    pypi
olefile                   0.46                     py36_0    anaconda
opencv                    4.1.1                         3    conda-forge
openh264                  1.8.0             hd9629dc_1000    conda-forge
openssl                   1.1.1                h1de35cc_0    anaconda
pandas                    0.25.3                   pypi_0    pypi
pandoc                    2.9.1.1                       0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
parso                     0.5.2                      py_0    conda-forge
pcre                      8.41              h0a44026_1003    conda-forge
pexpect                   4.7.0                    py36_0    conda-forge
pickleshare               0.7.5                 py36_1000    conda-forge
pillow                    6.2.1            py36hb68e598_0    anaconda
pip                       19.2.2                   py36_0  
pixman                    0.38.0            h01d97ff_1003    conda-forge
prometheus_client         0.7.1                      py_0    conda-forge
prompt_toolkit            3.0.2                      py_0    conda-forge
protobuf                  3.9.2            py36h6de7cb9_1    conda-forge
ptyprocess                0.6.0                   py_1001    conda-forge
py-opencv                 4.1.1            py36h5ca1d4c_3    conda-forge
pygments                  2.5.2                      py_0    conda-forge
pygpu                     0.7.6           py36h3b54f70_1000    conda-forge
pyparsing                 2.4.5                    pypi_0    pypi
pyrsistent                0.15.7           py36h0b31af3_0    conda-forge
python                    3.6.9                h359304d_0  
python-dateutil           2.8.1                      py_0    conda-forge
pytz                      2019.3                   pypi_0    pypi
pyyaml                    5.3              py36h0b31af3_0    conda-forge
pyzmq                     18.1.0           py36hee98d25_0    conda-forge
readline                  7.0                  h1de35cc_5  
s3fs                      0.4.0                    pypi_0    pypi
s3transfer                0.3.2                    pypi_0    pypi
scikit-learn              0.22.1                   pypi_0    pypi
scipy                     1.4.1                    pypi_0    pypi
send2trash                1.5.0                      py_0    conda-forge
setuptools                41.0.1                   py36_0  
six                       1.13.0                   py36_0    conda-forge
sqlite                    3.29.0               ha441bb4_0  
tensorboard               1.14.0                   py36_0    conda-forge
tensorflow                1.14.0          mkl_py36h933f829_0  
tensorflow-base           1.14.0          mkl_py36h655c25b_0  
tensorflow-estimator      1.14.0           py36h5ca1d4c_0    conda-forge
termcolor                 1.1.0                      py_2    conda-forge
terminado                 0.8.3                    py36_0    conda-forge
testpath                  0.4.4                      py_0    conda-forge
theano                    1.0.3                    py36_0    conda-forge
tk                        8.6.8                ha441bb4_0  
tornado                   6.0.3            py36h0b31af3_0    conda-forge
traitlets                 4.3.3                    py36_0    conda-forge
urllib3                   1.25.8                   pypi_0    pypi
wcwidth                   0.1.8                      py_0    conda-forge
webencodings              0.5.1                      py_1    conda-forge
werkzeug                  1.0.0                      py_0    conda-forge
wheel                     0.33.4                   py36_0  
wordcloud                 1.6.0                    pypi_0    pypi
wrapt                     1.12.0           py36h0b31af3_0    conda-forge
x264                      1!152.20180806       h1de35cc_0    conda-forge
xz                        5.2.4                h1de35cc_4  
yaml                      0.2.2                h0b31af3_1    conda-forge
zeromq                    4.3.2                h6de7cb9_2    conda-forge
zipp                      1.0.0                      py_0    conda-forge
zlib                      1.2.11               h1de35cc_3  
zstd                      1.4.0                ha9f0a20_0    conda-forge
In [2]:
import os
import numpy as np 
import cv2
from PIL import Image
from keras_unet.utils import plot_imgs, get_augmented
from sklearn.model_selection import train_test_split
from keras_unet.models import custom_unet
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam, SGD
from keras_unet.metrics import iou, iou_thresholded
from keras.callbacks import ModelCheckpoint
Using TensorFlow backend.
In [3]:
# Dataset locations, relative to the notebook's working directory.
# Each split (train / test / val) has one folder of images and one of masks.
_DATA_ROOT = 'data/'
IMG_TRAIN_PATH = _DATA_ROOT + 'train_images/'
MSK_TRAIN_PATH = _DATA_ROOT + 'train_masks/'
IMG_TEST_PATH = _DATA_ROOT + 'test_images/'
MSK_TEST_PATH = _DATA_ROOT + 'test_masks/'
IMG_VAL_PATH = _DATA_ROOT + 'val_images/'
MSK_VAL_PATH = _DATA_ROOT + 'val_masks/'
In [4]:
# Every dataset directory, grouped as (images, masks) per split.
ALL_PATHS = [
    IMG_TRAIN_PATH, MSK_TRAIN_PATH,
    IMG_TEST_PATH, MSK_TEST_PATH,
    IMG_VAL_PATH, MSK_VAL_PATH,
]
In [5]:
# PNG filenames in the training image / mask folders. os.listdir() returns
# entries in arbitrary, filesystem-dependent order, so sort them: this makes the
# sample order (and therefore the seeded train/val split later) reproducible.
all_training_images = sorted(f for f in os.listdir(IMG_TRAIN_PATH) if f.endswith('.png'))
all_training_masks = sorted(f for f in os.listdir(MSK_TRAIN_PATH) if f.endswith('.png'))
In [12]:
# Load every training image together with its mask, both resized to 256x256.
# The mask is looked up under the *image's* filename, so pairing is by name
# rather than by position in two separately listed directories; an image with
# no readable mask is an error instead of a silently dropped or misaligned sample.
imgs_list = []
masks_list = []
for fname in sorted(all_training_images):
    # Force 3-channel RGB so every sample stacks into one (N, 256, 256, 3) array
    # even if a PNG happens to be stored as greyscale, palette or RGBA.
    # The context manager closes the underlying file handle.
    with Image.open(IMG_TRAIN_PATH + fname) as img:
        imgs_list.append(np.array(img.convert('RGB').resize((256, 256))))
    msk = cv2.imread(MSK_TRAIN_PATH + fname, cv2.IMREAD_GRAYSCALE)
    if msk is None:
        # cv2.imread signals a missing/unreadable file by returning None; fail
        # here with the offending path rather than deep inside cv2.resize.
        raise FileNotFoundError('No readable mask for training image: ' + MSK_TRAIN_PATH + fname)
    # Nearest-neighbour interpolation keeps mask values within the original label set.
    msk = cv2.resize(msk, (256, 256), interpolation=cv2.INTER_NEAREST)
    masks_list.append(msk)
imgs_np = np.asarray(imgs_list)
masks_np = np.asarray(masks_list)
In [13]:
# Network input as float32, divided by 255 (assumes 8-bit pixel data).
X = imgs_np.astype(np.float32) / 255
# Targets: raw mask grey levels as float32; they are binarised in a later cell.
y = masks_np.astype(np.float32)
In [14]:
# Visual sanity check: 10 images next to their (not yet binarised) masks.
plot_imgs(org_imgs=X, mask_imgs=y, figsize=6, nm_img_to_plot=10)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [15]:
# Binarise masks: any non-zero grey level is foreground (1.0), zero is background (0.0).
# Staying in float32 avoids the int64 array np.where(..., 0, 1) would produce
# (twice the memory, and a dtype that differs from the float32 inputs X).
y = (y > 0).astype(np.float32)
In [16]:
# Same 10-sample check, now with binary masks.
plot_imgs(org_imgs=X, mask_imgs=y, figsize=6, nm_img_to_plot=10)
In [17]:
# Append a trailing channel axis, (N, H, W) -> (N, H, W, 1), so the targets match
# the network's single-channel output. Re-running on an already 4-D y is a no-op.
y = y.reshape(y.shape[:3] + (1,))
In [18]:
# Hold out 20% of the labelled samples for validation; the fixed seed keeps the
# split repeatable for a given sample order.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=0
)
for split_name, split_arr in (("x_train", X_train), ("y_train", y_train),
                              ("x_val", X_val), ("y_val", y_val)):
    print(split_name + ": ", split_arr.shape)
x_train:  (348, 256, 256, 3)
y_train:  (348, 256, 256, 1)
x_val:  (1395, 256, 256, 3)
y_val:  (1395, 256, 256, 1)
In [19]:
# Augmented training stream built by keras_unet's get_augmented: small rotations,
# vertical shifts, zoom and horizontal flips, with constant fill for areas the
# transforms expose. batch_size is a literal because BATCH_SIZE is only defined
# in a later cell; keep the two values in sync.
train_gen = get_augmented(
    X_train,
    y_train,
    batch_size=32,
    data_gen_args={
        'rotation_range': 10.,
        'height_shift_range': 0.02,
        'zoom_range': 0.2,
        'horizontal_flip': True,
        'vertical_flip': False,
        'fill_mode': 'constant',
    },
)
In [15]:
# Distinct mask values and how many pixels carry each (np.unique sorts the values).
unique, counts = np.unique(np.ravel(y), return_counts=True)
In [16]:
# Print each value's share of all mask pixels, in the order of `unique`.
for share in counts / counts.sum():
    print(str(share))
0.7938619324146785
0.20613806758532144
In [20]:
# Mini-batch size; should equal the batch_size given to get_augmented.
BATCH_SIZE = 32
# Shape of a single training sample (leading sample axis dropped), used as the
# network's input shape.
INPUT_SHAPE = X_train[0].shape
In [21]:
# 4-level U-Net from keras_unet: 32 base filters, batch norm, and a constant
# dropout of 0.3 at every level. The output activation is left at
# custom_unet's default.
model = custom_unet(
    input_shape=INPUT_SHAPE,
    num_layers=4,
    filters=32,
    use_batch_norm=True,
    dropout=0.3,
    dropout_change_per_layer=0.0,
)
WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4185: The name tf.truncated_normal is deprecated. Please use tf.random.truncated_normal instead.

WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.

WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.

WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
WARNING:tensorflow:From /anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

In [27]:
# Print the layer-by-layer architecture and parameter counts.
model.summary(print_fn=print)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 256, 256, 32) 128         conv2d_1[0][0]                   
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 256, 256, 32) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 32) 9248        dropout_1[0][0]                  
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 256, 256, 32) 128         conv2d_2[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 128, 128, 32) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 64) 18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 128, 128, 64) 256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 128, 128, 64) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 64) 36928       dropout_2[0][0]                  
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 128, 128, 64) 256         conv2d_4[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 64, 64, 64)   0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 128)  73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 128)  512         conv2d_5[0][0]                   
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 64, 64, 128)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 128)  147584      dropout_3[0][0]                  
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 64, 64, 128)  512         conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 32, 32, 128)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 256)  295168      max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 32, 32, 256)  1024        conv2d_7[0][0]                   
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 32, 32, 256)  0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  590080      dropout_4[0][0]                  
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 256)  1024        conv2d_8[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 16, 16, 256)  0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 512)  1180160     max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 16, 16, 512)  2048        conv2d_9[0][0]                   
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 16, 16, 512)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 512)  2359808     dropout_5[0][0]                  
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 16, 16, 512)  2048        conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 32, 32, 256)  524544      batch_normalization_10[0][0]     
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 32, 32, 512)  0           conv2d_transpose_1[0][0]         
                                                                 batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 256)  1179904     concatenate_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 32, 32, 256)  1024        conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 256)  590080      batch_normalization_11[0][0]     
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 32, 32, 256)  1024        conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 64, 64, 128)  131200      batch_normalization_12[0][0]     
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 256)  0           conv2d_transpose_2[0][0]         
                                                                 batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 64, 64, 128)  295040      concatenate_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 64, 64, 128)  512         conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 128)  147584      batch_normalization_13[0][0]     
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 64, 64, 128)  512         conv2d_14[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 128, 128, 64) 32832       batch_normalization_14[0][0]     
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 128, 128, 128 0           conv2d_transpose_3[0][0]         
                                                                 batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 128, 128, 64) 73792       concatenate_3[0][0]              
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 128, 128, 64) 256         conv2d_15[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 128, 128, 64) 36928       batch_normalization_15[0][0]     
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128, 128, 64) 256         conv2d_16[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 256, 256, 32) 8224        batch_normalization_16[0][0]     
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 256, 256, 64) 0           conv2d_transpose_4[0][0]         
                                                                 batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 256, 256, 32) 18464       concatenate_4[0][0]              
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 256, 256, 32) 128         conv2d_17[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 256, 256, 32) 9248        batch_normalization_17[0][0]     
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 256, 256, 32) 128         conv2d_18[0][0]                  
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 256, 256, 1)  33          batch_normalization_18[0][0]     
==================================================================================================
Total params: 7,771,873
Trainable params: 7,765,985
Non-trainable params: 5,888
__________________________________________________________________________________________________
In [50]:
# Where the best model (lowest validation loss) is written during training.
model_filename = 'segm_model_v1.h5'
callback_checkpoint = ModelCheckpoint(
    model_filename,
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)
# Stop once val_loss has not improved for 5 consecutive epochs.
early_stopper = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
In [53]:
def focal_loss(target, output, gamma=2):
    """Focal loss for use as a Keras loss function.

    Down-weights well-classified pixels by the factor (1 - p_t) ** gamma.

    Two cases, decided by the statically known size of the last axis of `output`:
    - size 1 (a single-channel map such as this notebook's U-Net output): binary
      focal loss on the clipped probabilities. Normalising a single channel by
      its own sum would turn every value into 1 and make the loss ~0 (or NaN
      where the prediction is exactly 0), so that step is skipped here.
    - otherwise: `output` is normalised to sum to 1 over the last axis and the
      categorical focal loss is computed (assumes `target` is one-hot over that
      axis -- TODO confirm for any multi-class use).

    Args:
        target: ground-truth tensor, same shape as `output`.
        output: predicted tensor.
        gamma: focusing parameter; 0 reduces to (binary/categorical) cross-entropy.

    Returns:
        Tensor of per-element losses, summed over the last axis.
    """
    eps = K.epsilon()
    if K.int_shape(output)[-1] == 1:
        # Binary focal loss: penalise both the foreground and background terms.
        output = K.clip(output, eps, 1. - eps)
        pos_term = K.pow(1. - output, gamma) * target * K.log(output)
        neg_term = K.pow(output, gamma) * (1. - target) * K.log(1. - output)
        return -K.sum(pos_term + neg_term, axis=-1)
    # Categorical focal loss over the last axis (unchanged behaviour).
    output = output / K.sum(output, axis=-1, keepdims=True)
    output = K.clip(output, eps, 1. - eps)
    return -K.sum(K.pow(1. - output, gamma) * target * K.log(output),
                  axis=-1)
In [24]:
import keras.backend as K
import tensorflow as tf
from keras_unet.losses import jaccard_distance
In [26]:
# Optimise the Jaccard distance with Adam and report pixel accuracy plus IoU
# (raw and thresholded) each epoch.
model.compile(
    optimizer='adam',
    loss=jaccard_distance,
    metrics=['accuracy', iou, iou_thresholded],
)
In [28]:
# Generator batches per epoch, derived from BATCH_SIZE rather than a repeated
# literal. The result is floored at 1 so fit_generator still gets a valid step
# count when the training set is smaller than one batch.
steps_per_epoch = max(1, len(X_train) // BATCH_SIZE)
In [30]:
# Train on the augmenting generator for up to 10 epochs, scoring the held-out
# split after each one; the callbacks handle early stopping and checkpointing.
history = model.fit_generator(
    train_gen,
    epochs=10,
    steps_per_epoch=steps_per_epoch,
    callbacks=[early_stopper, callback_checkpoint],
    validation_data=(X_val, y_val),
)
Epoch 1/10
10/10 [==============================] - 1504s 150s/step - loss: 343876.5328 - acc: 0.6104 - iou: 0.1986 - iou_thresholded: 0.2358 - val_loss: 783243.5010 - val_acc: 0.5735 - val_iou: 0.2083 - val_iou_thresholded: 0.2311

Epoch 00001: val_loss improved from inf to 783243.50103, saving model to segm_model_v1.h5
Epoch 2/10
10/10 [==============================] - 1532s 153s/step - loss: 188009.4134 - acc: 0.7504 - iou: 0.2006 - iou_thresholded: 0.2898 - val_loss: 1611262.6617 - val_acc: 0.6292 - val_iou: 0.2366 - val_iou_thresholded: 0.2862

Epoch 00002: val_loss did not improve from 783243.50103
Epoch 3/10
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-30-94244361b0ce> in <module>
      4     epochs=10,
      5     validation_data=(X_val, y_val),
----> 6     callbacks=[early_stopper, callback_checkpoint])

/anaconda3/envs/unet/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/anaconda3/envs/unet/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1416             use_multiprocessing=use_multiprocessing,
   1417             shuffle=shuffle,
-> 1418             initial_epoch=initial_epoch)
   1419 
   1420     @interfaces.legacy_generator_methods_support

/anaconda3/envs/unet/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    215                 outs = model.train_on_batch(x, y,
    216                                             sample_weight=sample_weight,
--> 217                                             class_weight=class_weight)
    218 
    219                 outs = to_list(outs)

/anaconda3/envs/unet/lib/python3.6/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1215             ins = x + y + sample_weights
   1216         self._make_train_function()
-> 1217         outputs = self.train_function(ins)
   1218         return unpack_singleton(outputs)
   1219 

/anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

/anaconda3/envs/unet/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
   2674         else:
-> 2675             fetched = self._callable_fn(*array_vals)
   2676         return fetched[:len(self.outputs)]
   2677 

/anaconda3/envs/unet/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

KeyboardInterrupt: 
In [60]:
# Restore trained weights for evaluation.
# NOTE(review): the ModelCheckpoint cell writes 'segm_model_v1.h5', but this cell
# loads 'segm_model_v1_c2.h5', which nothing in this notebook produces. A fresh
# Restart & Run All therefore depends on that file already existing -- confirm
# which checkpoint is intended.
model_filename = 'segm_model_v1_c2.h5'
model.load_weights(model_filename)
In [36]:
# PNG filenames in the test image / mask folders, sorted because os.listdir()
# order is arbitrary and filesystem-dependent.
all_testing_images = sorted(f for f in os.listdir(IMG_TEST_PATH) if f.endswith('.png'))
all_testing_masks = sorted(f for f in os.listdir(MSK_TEST_PATH) if f.endswith('.png'))
In [37]:
# Load every test image together with its mask, both resized to 256x256.
# The mask is looked up under the *image's* filename, so pairing is by name
# rather than by position in two separately listed directories; an image with
# no readable mask is an error instead of a silently dropped or misaligned sample.
imgs_list = []
masks_list = []
for fname in sorted(all_testing_images):
    # Force 3-channel RGB so every sample stacks into one (N, 256, 256, 3) array;
    # the context manager closes the underlying file handle.
    with Image.open(IMG_TEST_PATH + fname) as img:
        imgs_list.append(np.array(img.convert('RGB').resize((256, 256))))
    msk = cv2.imread(MSK_TEST_PATH + fname, cv2.IMREAD_GRAYSCALE)
    if msk is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError('No readable mask for test image: ' + MSK_TEST_PATH + fname)
    # Nearest-neighbour interpolation keeps mask values within the original label set.
    msk = cv2.resize(msk, (256, 256), interpolation=cv2.INTER_NEAREST)
    masks_list.append(msk)
imgs_np = np.asarray(imgs_list)
masks_np = np.asarray(masks_list)
In [38]:
# Binarise test masks: positive grey levels -> 1, zero -> 0 (platform default int dtype).
masks_np = (masks_np > 0).astype(int)
In [39]:
# Test inputs as float32, divided by 255 (assumes 8-bit pixel data), matching
# the training preprocessing; test targets as float32.
X_test = imgs_np.astype(np.float32) / 255
y_test = masks_np.astype(np.float32)
In [61]:
# Predict one output map per validation image.
# NOTE(review): X_test / y_test are loaded above but never evaluated; this cell
# and the plot below use the validation split -- confirm that is intended.
y_pred = model.predict(x=X_val, verbose=1)
1395/1395 [==============================] - 613s 440ms/step
In [62]:
from keras_unet.utils import plot_imgs

# Side-by-side view of 10 validation images, their ground-truth masks and the
# model's predictions.
plot_imgs(org_imgs=X_val, pred_imgs=y_pred, mask_imgs=y_val, nm_img_to_plot=10)